import csv
import os
import glob
import re

# Strip special tokens from the `caption` column of every CSV file in the stm
# folder and save the cleaned files (same basenames) to a separate folder.

# Input / output locations.
csv_folder = '/root/autodl-tmp/TEDspliter/train/stm'
output_folder = '/root/autodl-tmp/TEDspliter/train/stm_without_special_tokens'

# A {...}, (...) or <...> span (non-nested). These spans are treated as
# special tokens and removed from captions.
_SPECIAL_TOKEN_RE = re.compile(r'\{[^}]*\}|\([^)]*\)|<[^>]*>')
# Runs of two or more whitespace characters, typically left behind after a
# token between two words has been removed.
_MULTI_SPACE_RE = re.compile(r'\s{2,}')

# Header written to every output file.
_OUTPUT_HEADER = ['xmin', 'xmax', 'caption']


def clean_caption(caption):
    """Return *caption* with {...}, (...) and <...> spans removed.

    Runs of whitespace are collapsed to a single space, and leading/trailing
    whitespace (e.g. left by a token at the start or end) is stripped.
    """
    caption = _SPECIAL_TOKEN_RE.sub('', caption)
    return _MULTI_SPACE_RE.sub(' ', caption).strip()


def clean_csv_file(src_path, dst_path, encoding='utf-8'):
    """Write the xmin/xmax/caption columns of *src_path* to *dst_path*.

    The input must start with a header row containing at least the
    ``xmin``, ``xmax`` and ``caption`` columns; any other columns are
    dropped. Each caption is passed through :func:`clean_caption`. The
    output always starts with the header ``xmin,xmax,caption``.

    Raises:
        KeyError: if a required column is missing from the input header.
    """
    # newline='' on both files is required by the csv module so quoted
    # fields containing line breaks round-trip correctly.
    with open(src_path, 'r', newline='', encoding=encoding) as infile, \
            open(dst_path, 'w', newline='', encoding=encoding) as outfile:
        reader = csv.DictReader(infile)
        writer = csv.writer(outfile)
        writer.writerow(_OUTPUT_HEADER)
        writer.writerows(
            [row['xmin'], row['xmax'], clean_caption(row['caption'])]
            for row in reader
        )


def main(src_folder=csv_folder, dst_folder=output_folder):
    """Clean every ``*.csv`` file in *src_folder* into *dst_folder*."""
    os.makedirs(dst_folder, exist_ok=True)
    # Sorted so files are processed in a deterministic order.
    for src_path in sorted(glob.glob(os.path.join(src_folder, '*.csv'))):
        dst_path = os.path.join(dst_folder, os.path.basename(src_path))
        clean_csv_file(src_path, dst_path)


if __name__ == '__main__':
    main()